from PIL import Image, ImageDraw
import json
import os
import re
import pdb
import random
import argparse
import numpy as np
from collections import Counter
from IPython.display import display
from tqdm import tqdm
random.seed(42)  # Fix the random seed

def is_english_simple(text):
    """Return True when *text* consists solely of ASCII characters.

    The string is encoded to UTF-8 and the resulting bytes are decoded as
    ASCII; a decoding failure means at least one non-ASCII character is
    present.
    """
    utf8_bytes = text.encode(encoding='utf-8')
    try:
        utf8_bytes.decode('ascii')
    except UnicodeDecodeError:
        return False
    return True

# bbox -> center point, returned as a two-element list of floats
def bbox_2_point(bbox, dig=2):
    """Return the center ``[x, y]`` of ``bbox = [x1, y1, x2, y2]``.

    Each coordinate is rounded to 4 decimal places. The ``dig`` parameter is
    accepted but not used by the computation.
    """
    center_x = (bbox[0] + bbox[2]) / 2
    center_y = (bbox[1] + bbox[3]) / 2
    return [round(center_x, 4), round(center_y, 4)]

def normalize_click(click, size):
    """Scale a click ``(x, y)`` by an image ``size = (width, height)``.

    Returns ``[x / width, y / height]``.
    """
    px, py = click
    img_w, img_h = size
    return [px / img_w, py / img_h]

def normalize_bbox(bbox, size):
    """Scale ``bbox = (x1, y1, x2, y2)`` by an image ``size = (width, height)``.

    x coordinates are divided by the width and y coordinates by the height;
    the result is returned as a four-element list in the same order.
    """
    left, top, right, bottom = bbox
    img_w, img_h = size
    return [left / img_w, top / img_h, right / img_w, bottom / img_h]

def draw_point_bbox(image_path, start=None, end=None, radius=5, line=3):
    """Open an image and overlay ``start`` (red) and ``end`` (blue) markers.

    Each marker holds coordinates that are multiplied by the image width and
    height before drawing. A 2-element marker is drawn as a filled circle of
    the given ``radius``; a 4-element marker is drawn as a rectangle outline
    with stroke width ``line``. Markers of any other length, or ``None``,
    are ignored.

    Returns the PIL image with the overlays drawn on it.
    """
    image = Image.open(image_path)
    draw = ImageDraw.Draw(image)
    width, height = image.size

    # ``start`` is drawn before ``end``, so ``end`` paints on top on overlap.
    for marker, color in ((start, 'red'), (end, 'blue')):
        if marker is None:
            continue
        if len(marker) == 2:
            cx, cy = marker[0] * width, marker[1] * height
            draw.ellipse((cx - radius, cy - radius, cx + radius, cy + radius), fill=color, outline=color)
        elif len(marker) == 4:
            left, top, right, bottom = marker[0] * width, marker[1] * height, marker[2] * width, marker[3] * height
            draw.rectangle([left, top, right, bottom], outline=color, width=line)

    return image

def check_instruction(instruction):
    """Return True when *instruction* passes the text filter.

    The only active rule is that the text must be ASCII-only (delegated to
    ``is_english_simple``). Earlier rules on length, emptiness and curly
    braces are disabled.
    """
    return True if is_english_simple(instruction) else False

def get_answer(step):
    """Convert a preprocessed step into the canonical answer dict.

    Args:
        step: Mapping with the keys ``action_type`` (string, any case),
            ``action_value`` and ``point``. ``point`` is ``[x, y]`` for
            CLICK / INPUT / SELECT / HOVER / TAP, ``[start, end]`` (two
            coordinate lists) for SELECT_TEXT / SWIPE, or ``None``.

    Returns:
        ``{'action': ..., 'value': ..., 'position': ...}`` where ``action`` is
        the upper-cased action type, ``value`` is ``action_value`` for
        INPUT / SELECT / ANSWER / COPY / SCROLL (``None`` otherwise), and
        ``position`` holds the coordinates rounded to 2 decimals (``None``
        when the action carries no position or ``point`` is ``None``).

    Raises:
        KeyError: if one of the three required keys is missing.
    """
    action = step['action_type'].upper()
    action_value = step['action_value']
    action_point = step['point']

    position = None
    # A missing point is tolerated for every positional action, so the
    # two-point branch no longer fails on None while the one-point branch
    # silently accepts it.
    if action_point is not None:
        if action in ('CLICK', 'INPUT', 'SELECT', 'HOVER', 'TAP'):
            position = [round(coord, 2) for coord in action_point]
        elif action in ('SELECT_TEXT', 'SWIPE'):
            start_point = [round(coord, 2) for coord in action_point[0]]
            end_point = [round(coord, 2) for coord in action_point[1]]
            position = [start_point, end_point]

    value = None
    if action in ('INPUT', 'SELECT', 'ANSWER', 'COPY', 'SCROLL'):
        value = action_value

    return {'action': action, 'value': value, 'position': position}

def get_answer_by_list(step):
    """Build the answer for a single step or for a list of steps.

    A dict is passed straight to ``get_answer`` and its result is returned.
    A list is converted item by item, each answer is stringified, and the
    pieces are joined with ``','`` into one string. Any other input type
    yields ``None``.
    """
    if isinstance(step, dict):
        return get_answer(step)
    if isinstance(step, list):
        return ','.join(str(get_answer(item)) for item in step)
    return None


# Load the image-id -> image-path lookup tables of the three GUIAct subsets.
# Files are opened with a context manager so the handles are closed right
# away, and decoded explicitly as UTF-8 instead of the platform default.
with open("Your data path/GUIAct/images/guiact/web-single/image_id2path.json", "r", encoding="utf-8") as _id2path_file:
    web_single_id2path = json.load(_id2path_file)
with open("Your data path/GUIAct/images/guiact/web-multi/image_id2path.json", "r", encoding="utf-8") as _id2path_file:
    web_multi_id2path = json.load(_id2path_file)
with open("Your data path/GUIAct/images/guiact/smartphone/image_id2path.json", "r", encoding="utf-8") as _id2path_file:
    smartphone_id2path = json.load(_id2path_file)

# Merged lookup used by every stage below. On a duplicate key the table
# merged later wins (web-single, then web-multi, then smartphone).
id2path = {}
id2path.update(web_single_id2path)
id2path.update(web_multi_id2path)
id2path.update(smartphone_id2path)

# Root directory of the dataset (placeholder; set before running).
data_dir = 'Your data path'

parser = argparse.ArgumentParser(description="Process data for pre-training.")
parser.add_argument("--web_imgs", default=f'{data_dir}/GUIAct/images', help="Path to the directory containing web images.")

# Accumulators shared by all three stages (web-single, web-multi, smartphone).
total_step = []  # output records, written to disk at the end of the script
total_i = 0      # running record counter, used for "id" and "step_id"

# web-single
# parse_args([]) parses an empty argument list, so the real command line is
# ignored and --web_imgs always takes its default value.
args = parser.parse_args([])
args.stage = 'web-single'
args.split = 'train'
# args.split = 'test'

args.web_json = f'{data_dir}/GUIAct/{args.stage}_{args.split}_data.json'
# Context manager closes the handle; JSON is decoded explicitly as UTF-8.
with open(args.web_json, "r", encoding="utf-8") as _web_json_file:
    web_train = json.load(_web_json_file)

# Prompt text stored verbatim in the "problem" field of every web-single
# record built below.
# NOTE(review): this string is never passed through str.format() in this file,
# so the doubled braces in the <answer> example are emitted literally as
# "{{" / "}}" — confirm whether a downstream consumer formats it.
# NOTE(review): "ext action" and "Decesion" look like typos of "next action"
# and "Decision"; they are runtime text (and tag names) and are left untouched.
_WEB_SINGLE_SYSTEM = """You are an assistant trained to navigate the website screen. 
Given a task instruction, a screenshot, and a history action summary, output the think and ext action and wait for the next observation. 
The think must strictly follow these reasoning steps:
(1) Progress Estimation: Interface Comprehension and Progress Estimation
(2) Decesion Reasoning: Strategy Formulation
(3) History Summary: Update the history action summary according to the last history action summary and the action you executed.

## Action Space
1. `CLICK`: Click on an element, value is not applicable and the position [x,y] is required. 
2. `INPUT`: Type a string into an element, value is a string to type and the position [x,y] is required. 
3. `SELECT`: Select a value for an element, value is not applicable and the position [x,y] is required. 
4. `HOVER`: Hover on an element, value is not applicable and the position [x,y] is required.
5. `ANSWER`: Answer the question, value is the answer and the position is not applicable.
6. `ENTER`: Enter operation, value and position are not applicable.
7. `SCROLL`: Scroll the screen, value is the direction to scroll (up, down, right, left) and the position is not applicable.
8. `SELECT_TEXT`: Select some text content, value is not applicable and position [[x1,y1], [x2,y2]] is the start and end position of the select operation.
9. `COPY`: Copy the text, value is the text to copy and the position is not applicable.

## Output Format
<Progress Estimation>
...
</Progress Estimation>
<Decesion Reasoning>
...
</Decesion Reasoning>
<answer>
{{'action': 'ACTION_TYPE', 'value': 'element', 'position': [x,y]}}
</answer>
<History Summary>
...
</History Summary>

If value or position is not applicable, set it as `None` in <answer>.
Position might be [[x1,y1], [x2,y2]] if the action requires a start and end position.
Position represents the relative coordinates on the screenshot and should be scaled to a range of 0-1.
"""


# Convert web-single samples into training records.
# A sample produces a record only when its task text is ASCII-only and exactly
# one of its labelled actions survives the per-action value filter.
for _, x in enumerate(tqdm(web_train)):
    if not isinstance(x, dict):
        continue

    act_list = x['actions_label']
    img_uid = x['image_id']
    img_url = os.path.join(args.web_imgs, img_uid+'.png')
    image_path = id2path[img_uid]

    img_size = [x['image_size']['width'], x['image_size']['height']]

    step_history = []
    task = x['question']
    if not check_instruction(task):
        # Print the rejected (non-ASCII) task and drop the whole sample.
        print(task)
        continue
    instruction = x['thoughts']

    # click - bbox
    # input - bbox + text
    x_act = []          # actions that passed the value filter
    kept_bboxes = []    # bbox of each kept action, parallel to x_act
    previous_actions = []
    for i, xi in enumerate(act_list):
        action_type = xi['name'].lower()
        bbox = None
        point = None
        action_value = None

        if action_type in ['click', 'input', 'select', 'hover']:
            # Pull every decimal number out of the element string and use
            # the center of the resulting box as the action position.
            bbox_str = xi['element']['related']
            bbox_val = re.findall(r'\d+\.\d+', bbox_str)
            bbox = [float(v) for v in bbox_val]
            point = bbox_2_point(bbox)

        if action_type in ['input', 'select', 'answer']:
            action_value = xi['text']

        if action_type in ['scroll']:
            down = float(xi['scroll']['related']['down'])
            right = float(xi['scroll']['related']['right'])
            if down > 0:
                action_value = 'down'
            elif down < 0:
                action_value = 'up'

            # A non-zero horizontal component overrides the vertical one.
            if right > 0:
                action_value = 'right'
            elif right < 0:
                action_value = 'left'

        xi['action_type'] = action_type
        xi['action_value'] = action_value
        if action_value is not None and not check_instruction(action_value):
            continue
        xi['point'] = point
        x_act.append(xi)
        kept_bboxes.append(bbox)

    if len(x_act) != 1:
        continue

    # previous_actions is re-created for every sample, so it is always empty
    # here and the history string stays "".
    previous_step = ""
    for i, action in enumerate(previous_actions):
        previous_step += 'Step' + str(i) + ', previous action: ' + action[:-1] + "}. "

    answer_dict = get_answer_by_list(x_act[0])
    cur_answer = str(answer_dict)

    # Use the bbox of the action that was actually kept. The loop variable
    # still holds the bbox of the last iterated action, which may be a
    # different (filtered-out) one.
    bbox = kept_bboxes[0]

    total_step.append({
        "id": "guiact_web_single_{}".format(total_i),
        "step_id": total_i,
        "image": image_path,
        "problem": _WEB_SINGLE_SYSTEM,
        "solution": cur_answer,
        "task": task,
        "history": previous_step,
        "bbox_ref": bbox,
        "is_last": True,
        "is_first": True,
        })
    previous_actions.append(cur_answer)
    total_i += 1

# web-multi
# parse_args([]) parses an empty argument list, so the real command line is
# ignored and every option takes its default value.
args = parser.parse_args([])
args.stage = 'web-multi'
args.split = 'train'
# args.split = 'test'

args.web_json = f'{data_dir}/GUIAct/{args.stage}_{args.split}_data.json'
# Context manager closes the handle; JSON is decoded explicitly as UTF-8.
with open(args.web_json, "r", encoding="utf-8") as _web_json_file:
    web_train = json.load(_web_json_file)

# Prompt text stored verbatim in the "problem" field of every web-multi
# record built below.
# NOTE(review): this string is never passed through str.format() in this file,
# so the doubled braces in the <answer> example are emitted literally as
# "{{" / "}}" — confirm whether a downstream consumer formats it.
# NOTE(review): "ext action" and "Decesion" look like typos of "next action"
# and "Decision"; they are runtime text (and tag names) and are left untouched.
_WEB_MULTI_SYSTEM = """You are an assistant trained to navigate the website screen. 
Given a task instruction, a screenshot, and a history action summary, output the think and ext action and wait for the next observation. 
The think must strictly follow these reasoning steps:
(1) Progress Estimation: Interface Comprehension and Progress Estimation
(2) Decesion Reasoning: Strategy Formulation
(3) History Summary: Update the history action summary according to the last history action summary and the action you executed.


## Action Space
1. `CLICK`: Click on an element, value is not applicable and the position [x,y] is required. 
2. `INPUT`: Type a string into an element, value is a string to type and the position [x,y] is required. 
3. `SELECT`: Select a value for an element, value is not applicable and the position [x,y] is required. 
4. `HOVER`: Hover on an element, value is not applicable and the position [x,y] is required.
5. `ANSWER`: Answer the question, value is the answer and the position is not applicable.
6. `ENTER`: Enter operation, value and position are not applicable.
7. `SCROLL`: Scroll the screen, value is the direction to scroll (up, down, right, left) and the position is not applicable.
8. `SELECT_TEXT`: Select some text content, value is not applicable and position [[x1,y1], [x2,y2]] is the start and end position of the select operation.
9. `COPY`: Copy the text, value is the text to copy and the position is not applicable.

## Output Format
<Progress Estimation>
...
</Progress Estimation>
<Decesion Reasoning>
...
</Decesion Reasoning>
<answer>
{{'action': 'ACTION_TYPE', 'value': 'element', 'position': [x,y]}}
</answer>
<History Summary>
...
</History Summary>

If value or position is not applicable, set it as `None` in <answer>.
Position might be [[x1,y1], [x2,y2]] if the action requires a start and end position.
Position represents the relative coordinates on the screenshot and should be scaled to a range of 0-1.
"""

# total_step = []
# total_i = 0
# The two resets above are disabled, so web-multi records are appended to the
# same total_step list and total_i keeps counting from the previous stage.
count_web = {}  # number of labelled actions in a record -> how many records have that count
web_multi_think = []  # assigned here but not read anywhere in this block
# Emit one training record per labelled action of every web-multi record whose
# task text is ASCII-only.
for now_id, x in enumerate(tqdm(web_train)):
    if isinstance(x, dict):
        act_list = x['actions_label']
        img_uid = x['image_id']
        image_path = id2path[img_uid]
        # img_url = os.path.join(args.web_imgs, img_uid+'.png')
        # image = Image.open(img_url)
        # img_size = image.size
        # img_array = np.array(image)
        # uniform = np.all(img_array == img_array[0, 0])
        # if uniform:
        #     print(img_uid)
        #     continue
        
        img_size = [x['image_size']['width'], x['image_size']['height']]
        
        step_history = []
        task = x['question']
        instruction = x['thoughts']

        # Records with a non-ASCII task are printed and skipped entirely.
        if not check_instruction(task):
            print(task)
            continue
        
        # click - bbox
        # select - [click1, click2]
        # input - bbox + text
        count_web[len(act_list)] = count_web.get(len(act_list), 0) + 1
        sub = []
        # The action history is cleared when the last '_'-separated field of
        # the uid is 0.
        # NOTE(review): presumably that field is the step index within an
        # episode; if the index-0 record is skipped by the task filter above,
        # this reset does not run for that episode — confirm against the data.
        if int(x['uid'].split('_')[-1]) == 0:
            previous_actions = []
        for i, xi in enumerate(act_list):
            action_type = xi['name'].lower()
            bbox = None
            point = None
            action_value = None
            
            if action_type in ['select_text']:
                # Start and end positions: every decimal number found in the
                # 'from' / 'to' strings, converted to float.
                start = xi['dual_point']['related']['from']
                start = re.findall(r'\d+\.\d+', start)
                start = [float(x) for x in start]
    
                end = xi['dual_point']['related']['to']
                end = re.findall(r'\d+\.\d+', end)
                end = [float(x) for x in end]
                point = [start, end]
            
            if action_type in ['click', 'input', 'select', 'hover']:
                # The action position is the center of the element's box.
                bbox_str = xi['element']['related']
                bbox_val = re.findall(r'\d+\.\d+', bbox_str)
                bbox = [float(x) for x in bbox_val]
                # bbox = normalize_bbox(bbox, img_size)
                point = bbox_2_point(bbox)
                # img_draw = draw_point_bbox(img_url, point)
                # print(task)
                # display(img_draw)

            if action_type in ['input', 'select', 'answer', 'copy']:
                action_value = xi['text']

            if action_type in ['scroll']:
                down = float(xi['scroll']['related']['down'])
                right = float(xi['scroll']['related']['right'])
                if down > 0:
                    action_value = 'down'
                elif down < 0:
                    action_value = 'up'
                    
                # A non-zero horizontal component overrides the vertical one;
                # if both are zero, action_value stays None.
                if right > 0:
                    action_value = 'right'
                elif right < 0:
                    action_value = 'left'

            xi['action_type'] = action_type
            xi['action_value'] = action_value
            # Actions whose value is non-ASCII are dropped: no record is
            # emitted and nothing is added to the history.
            if action_value is not None and not check_instruction(action_value):
                continue
            
            xi['point'] = point

            # History string built from at most the four most recent answers;
            # the step index j restarts at 0 for that window.
            previous_step = ""
            for j, action in enumerate(previous_actions[-4:]):
                # action[:-1] drops the last character of the stored answer
                # string and "}. " puts a closing brace back plus a separator.
                previous_step += 'Step' + str(j) + ', previous action: ' + action[:-1] + "}. "

            answer_dict = get_answer_by_list(xi)
            cur_answer = str(answer_dict)

            # is_last is True for the final action of the final record.
            # Otherwise it is True only for the last action of a record whose
            # uid index is not exactly one less than the next record's index.
            if now_id == len(web_train) - 1 and i == len(act_list) - 1:
                is_last = True
            else:
                is_last = (i == len(act_list)-1 and int(x['uid'].split('_')[-1]) != int(web_train[now_id+1]['uid'].split('_')[-1])-1)
            total_step.append({
                "id": "guiact_web_multi_{}".format(total_i),
                "step_id": total_i,
                "image": image_path,
                "problem": _WEB_MULTI_SYSTEM,
                "solution": cur_answer,
                "task": task,
                "history": previous_step,
                "bbox_ref": bbox,
                "is_last": is_last,
                "is_first": int(x['uid'].split('_')[-1]) == 0,
                })
            
            previous_actions.append(cur_answer)
            total_i += 1



# smartphone
# parse_args([]) parses an empty argument list, so the real command line is
# ignored and every option takes its default value.
args = parser.parse_args([])
args.stage = 'smartphone'
args.split = 'train'
# args.split = 'test'

args.web_json = f'{data_dir}/GUIAct/{args.stage}_{args.split}_data.json'
# Context manager closes the handle; JSON is decoded explicitly as UTF-8.
with open(args.web_json, "r", encoding="utf-8") as _web_json_file:
    web_train = json.load(_web_json_file)

# Prompt text stored verbatim in the "problem" field of every smartphone
# record built below.
# NOTE(review): this string is never passed through str.format() in this file,
# so the doubled braces in the <answer> example are emitted literally as
# "{{" / "}}" — confirm whether a downstream consumer formats it.
# NOTE(review): "ext action" and "Decesion" look like typos of "next action"
# and "Decision"; they are runtime text (and tag names) and are left untouched.
_SMARTPHONE_SYSTEM = """You are an assistant trained to navigate the smartphone screen. 
Given a task instruction, a screenshot, and a history action summary, output the think and ext action and wait for the next observation. 
The think must strictly follow these reasoning steps:
(1) Progress Estimation: Interface Comprehension and Progress Estimation
(2) Decesion Reasoning: Strategy Formulation
(3) History Summary: Update the history action summary according to the last history action summary and the action you executed.


## Action Space
1. `INPUT`: Type a string into an element, value is a string to type and the position [x,y] is not applicable.
2. `SWIPE`: Swipe the screen, value is not applicable and the position [[x1,y1], [x2,y2]] is the start and end position of the swipe operation.
3. `TAP`: Tap on an element, value is not applicable and the position [x,y] is required.
4. `ANSWER`: Answer the question, value is the status (e.g., 'task complete') and the position is not applicable.
5. `ENTER`: Enter operation, value and position are not applicable.

## Output Format
<Progress Estimation>
...
</Progress Estimation>
<Decesion Reasoning>
...
</Decesion Reasoning>
<answer>
{{'action': 'ACTION_TYPE', 'value': 'element', 'position': [x,y]}}
</answer>
<History Summary>
...
</History Summary>

If value or position is not applicable, set it as `None` in <answer>.
Position might be [[x1,y1], [x2,y2]] if the action requires a start and end position.
Position represents the relative coordinates on the screenshot and should be scaled to a range of 0-1.
"""

# total_step and total_i are deliberately not reset here: smartphone records
# are appended after the web ones and keep counting from the same counter.
count_mobile = {}  # len(step_history) seen at each group boundary -> count
step_history = []
mobile_think = []
sub = []

# Convert smartphone samples into training records, one per sample. Samples
# are grouped into episodes by their image id with the last two
# '_'-separated fields removed; the action history is kept per episode.
past_group_id = None
for now_id, x in enumerate(tqdm(web_train)):
    if not isinstance(x, dict):
        continue

    act_list = x['actions_label']
    img_uid = x['image_id']
    img_url = os.path.join(args.web_imgs, img_uid+'.png')
    img_size = [x['image_size']['width'], x['image_size']['height']]
    image_path = id2path[img_uid]

    task = x['question']
    instruction = x['thoughts']

    if not check_instruction(task):
        # Print the rejected (non-ASCII) task and skip the sample.
        print(task)
        continue

    group_id = '_'.join(img_uid.split('_')[:-2])
    if group_id != past_group_id:
        # A new episode starts: flush the per-episode state.
        count_mobile[len(step_history)] = count_mobile.get(len(step_history), 0) + 1
        if len(sub) > 1:
            mobile_think.append(sub)
        step_history = []
        sub = []
        previous_actions = []
        # Remember the current episode. Without this update the comparison
        # above is always true, the history is cleared on every sample and
        # the "history" field is always empty.
        past_group_id = group_id

    # Here 'actions_label' is a single action mapping, not a list of actions.
    xi = act_list
    action_type = xi['name'].lower()
    bbox = None
    point = None
    action_value = None

    if action_type in ['swipe']:
        # Start and end positions: every decimal number found in the
        # 'from' / 'to' strings, converted to float.
        start = xi['dual_point']['related']['from']
        start = re.findall(r'\d+\.\d+', start)
        start = [float(v) for v in start]

        end = xi['dual_point']['related']['to']
        end = re.findall(r'\d+\.\d+', end)
        end = [float(v) for v in end]
        point = [start, end]

    if action_type in ['tap']:
        point_str = xi['point']['related']
        point_val = re.findall(r'\d+\.\d+', point_str)
        point = [float(v) for v in point_val]

    if action_type in ['input', 'answer']:
        action_value = xi['text']

    xi['action_type'] = action_type
    xi['action_value'] = action_value
    # The ASCII filter on action_value is intentionally not applied here.

    # Replaces the raw 'point' entry of the annotation with the parsed value.
    xi['point'] = point
    xi['img_url'] = img_uid

    # History string built from at most the four most recent answers of the
    # current episode; the step index j restarts at 0 for that window.
    previous_step = ""
    for j, action in enumerate(previous_actions[-4:]):
        previous_step += 'Step' + str(j) + ', previous action: ' + action[:-1] + "}. "

    answer_dict = get_answer_by_list(xi)
    cur_answer = str(answer_dict)

    # The step is the last of its episode when it is the final sample of the
    # file or when the next sample belongs to a different group.
    if now_id == len(web_train) - 1:
        is_last = True
    else:
        is_last = ('_'.join(web_train[now_id+1]['image_id'].split('_')[:-2]) != group_id)
    total_step.append({
        "id": "guiact_smartphone_{}".format(total_i),
        "step_id": total_i,
        "image": image_path,
        "problem": _SMARTPHONE_SYSTEM,
        "solution": cur_answer,
        "task": task,
        "history": previous_step,
        "bbox_ref": bbox,
        "is_last": is_last,
        "is_first": int(x['uid'].split('_')[-1]) == 0,
        })

    previous_actions.append(cur_answer)
    total_i += 1

# Write every accumulated record (web-single, web-multi and smartphone) to a
# single JSON Lines file.
import jsonlines  
save_url = "Your save path"  # output file path (placeholder; set before running)
with jsonlines.open(save_url, mode="w") as writer:
    writer.write_all(total_step)